In [3]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import random
random.seed(1100038344)
import survivalstan
import numpy as np
import pandas as pd
from stancache import stancache
from matplotlib import pyplot as plt
In order to demonstrate the use of this model, we will first simulate some survival data using `survivalstan.sim.sim_data_exp_correlated`. As the name implies, this function simulates data assuming a constant hazard throughout the follow-up time period, which is consistent with the Exponential survival function.
This function includes two simulated covariates by default (`age` and `sex`). We also simulate a situation where hazard is a function of the simulated value for `sex`.
We also center the `age` variable since this will make it easier to interpret estimates of the baseline hazard.
In [4]:
# Simulate exponential survival times (or load them from the stancache file
# cache). Keyword order is kept as-is so the cached call is unchanged.
d = stancache.cached(
    survivalstan.sim.sim_data_exp_correlated,
    N=100,
    censor_time=20,
    rate_form='1 + sex',
    rate_coefs=[-3, 0.5],
)

# Center age on its sample mean.
mean_age = d['age'].mean()
d['age_centered'] = d['age'] - mean_age
Aside: In order to make this a more reproducible example, this code is using a file-caching function `stancache.cached` to wrap a function call to `survivalstan.sim.sim_data_exp_correlated`.
Here is what these data look like - this is `per-subject` or `time-to-event` form:
In [5]:
# Preview the first five rows of the simulated data frame.
d.head(n=5)
Out[5]:
It's not that obvious from the field names, but in this example "subjects" are indexed by the field `index`.
We can plot these data using `lifelines`, or the rudimentary plotting functions provided by `survivalstan`.
In [6]:
# Draw one observed-survival curve per level of `sex`, labelled by level.
for sex_label in ('female', 'male'):
    survivalstan.utils.plot_observed_survival(
        df=d[d['sex'] == sex_label],
        event_col='event',
        time_col='t',
        label=sex_label,
    )
plt.legend()
Out[6]:
In [7]:
# Stan program for model 1: gamma survival likelihood assembled from
# separate helper functions.
#   - log_S: element-wise gamma_lccdf (log survival) for each time/rate pair.
#   - log_h: log hazard, computed as gamma_lpdf minus log_S.
#   - surv_gamma_lpdf: sums d .* log_h + log_S over observations, so rows with
#     d = 1 contribute the log density and rows with d = 0 contribute only
#     the log survival.
# The rate passed to the likelihood is mu = exp(x * beta), filled in by an
# explicit loop. Priors: alpha ~ gamma(0.01, 0.01), beta ~ normal(0, 5).
# NOTE(review): surv_gamma_lpdf evaluates log_S twice per call (once
# directly and once inside log_h).
# NOTE(review): the embedded Stan comment on `x` says "n rows and H columns",
# but the declaration is matrix[N, M].
model_code = '''
functions {
// Defines the log survival
vector log_S (vector t, real shape, vector rate) {
vector[num_elements(t)] log_S;
for (i in 1:num_elements(t)) {
log_S[i] = gamma_lccdf(t[i]|shape,rate[i]);
}
return log_S;
}
// Defines the log hazard
vector log_h (vector t, real shape, vector rate) {
vector[num_elements(t)] log_h;
vector[num_elements(t)] ls;
ls = log_S(t,shape,rate);
for (i in 1:num_elements(t)) {
log_h[i] = gamma_lpdf(t[i]|shape,rate[i]) - ls[i];
}
return log_h;
}
// Defines the sampling distribution
real surv_gamma_lpdf (vector t, vector d, real shape, vector rate) {
vector[num_elements(t)] log_lik;
real prob;
log_lik = d .* log_h(t,shape,rate) + log_S(t,shape,rate);
prob = sum(log_lik);
return prob;
}
}
data {
int N; // number of observations
vector<lower=0>[N] y; // observed times
vector<lower=0,upper=1>[N] event; // censoring indicator (1=observed, 0=censored)
int M; // number of covariates
matrix[N, M] x; // matrix of covariates (with n rows and H columns)
}
parameters {
vector[M] beta; // Coefficients in the linear predictor (including intercept)
real<lower=0> alpha; // shape parameter
}
transformed parameters {
vector[N] linpred;
vector[N] mu;
linpred = x*beta;
for (i in 1:N) {
mu[i] = exp(linpred[i]);
}
}
model {
alpha ~ gamma(0.01,0.01);
beta ~ normal(0,5);
y ~ surv_gamma(event, alpha, mu);
}
'''
Now, we are ready to fit our model using `survivalstan.fit_stan_survival_model`.
We pass a few parameters to the fit function, many of which are required. See ?survivalstan.fit_stan_survival_model for details.
Similar to what we did above, we are asking `survivalstan` to cache this model fit object. See stancache for more details on how this works. Also, if you didn't want to use the cache, you could omit the parameter `FIT_FUN` and `survivalstan` would use the standard pystan functionality.
In [8]:
# Fit model 1 (helper-function likelihood) to the simulated data; the fit is
# routed through stancache via FIT_FUN.
testfit = survivalstan.fit_stan_survival_model(
    model_cohort='model 1',
    model_code=model_code,
    df=d,
    time_col='t',
    event_col='event',
    formula='~ age_centered + sex',
    iter=5000,
    chains=4,
    seed=9001,
    FIT_FUN=stancache.cached_stan_fit,
    drop_intercept=False,
)
In [9]:
# Run time noted for the model 1 fit in the previous cell: 0:00:40.518775 elapsed
In [10]:
# Posterior summary for model 1, limited to lp__, alpha and beta.
survivalstan.utils.print_stan_summary(
    [testfit],
    pars=['lp__', 'alpha', 'beta'],
)
In [11]:
# Stan program for model 2: same likelihood as d*log_h + log_S, but written
# inline in a single loop inside surv_gamma_lpdf, with no separate log_S and
# log_h helper functions.
# Each row contributes
#   d[i] * (gamma_lpdf - gamma_lccdf) + gamma_lccdf.
# mu = exp(x * beta) is computed with a vectorised exp(); linpred is declared
# inside a nested local block, and mu is declared with lower=0.
# Priors: alpha ~ gamma(0.01, 0.01), beta ~ normal(0, 5).
# NOTE(review): the embedded Stan comment "Defines the log survival" labels a
# function that returns the summed log-likelihood, and the comment on `x`
# says "n rows and H columns" although the declaration is matrix[N, M].
model_code2 = '''
functions {
// Defines the log survival
real surv_gamma_lpdf (vector t, vector d, real shape, vector rate) {
vector[num_elements(t)] log_lik;
real prob;
for (i in 1:num_elements(t)) {
log_lik[i] = d[i] * (gamma_lpdf(t[i]|shape,rate[i]) - gamma_lccdf(t[i]|shape,rate[i]))
+ gamma_lccdf(t[i]|shape,rate[i]);
}
prob = sum(log_lik);
return prob;
}
}
data {
int N; // number of observations
vector<lower=0>[N] y; // observed times
vector<lower=0,upper=1>[N] event; // censoring indicator (1=observed, 0=censored)
int M; // number of covariates
matrix[N, M] x; // matrix of covariates (with n rows and H columns)
}
parameters {
vector[M] beta; // Coefficients in the linear predictor (including intercept)
real<lower=0> alpha; // shape parameter
}
transformed parameters {
vector<lower=0>[N] mu;
{
vector[N] linpred;
linpred = x*beta;
mu = exp(linpred);
}
}
model {
alpha ~ gamma(0.01,0.01);
beta ~ normal(0,5);
y ~ surv_gamma(event, alpha, mu);
}
'''
In [12]:
# Fit model 2 (single-loop inline likelihood) with the same data, formula and
# sampler settings as model 1.
testfit2 = survivalstan.fit_stan_survival_model(
    model_cohort='model 2',
    model_code=model_code2,
    df=d,
    time_col='t',
    event_col='event',
    formula='~ age_centered + sex',
    iter=5000,
    chains=4,
    seed=9001,
    FIT_FUN=stancache.cached_stan_fit,
    drop_intercept=False,
)
In [13]:
# Run time noted for the model 2 fit in the previous cell: 0:00:21.081723 elapsed
In [14]:
# Posterior summary for model 2, limited to lp__, alpha and beta.
survivalstan.utils.print_stan_summary(
    [testfit2],
    pars=['lp__', 'alpha', 'beta'],
)
In [15]:
# Stan program for model 3: per-row likelihood written with log_mix, using
# the event indicator d[i] as the mixing weight between gamma_lpdf (log
# density) and gamma_lccdf (log survival).
# NOTE(review): this matches the d*log_f + (1-d)*log_S likelihood of the
# other variants only when every d[i] is exactly 0 or 1; the data block
# merely bounds `event` to [0, 1] -- confirm the input is binary.
# mu = exp(x * beta) is computed with a vectorised exp(); linpred and mu are
# both declared at the top level of transformed parameters.
# Priors: alpha ~ gamma(0.01, 0.01), beta ~ normal(0, 5).
# NOTE(review): the embedded Stan comment "Defines the log survival" labels a
# function that returns the summed log-likelihood, and the comment on `x`
# says "n rows and H columns" although the declaration is matrix[N, M].
model_code3 = '''
functions {
// Defines the log survival
real surv_gamma_lpdf (vector t, vector d, real shape, vector rate) {
vector[num_elements(t)] log_lik;
real prob;
for (i in 1:num_elements(t)) {
log_lik[i] = log_mix(d[i], gamma_lpdf(t[i]|shape,rate[i]), gamma_lccdf(t[i]|shape,rate[i]));
}
prob = sum(log_lik);
return prob;
}
}
data {
int N; // number of observations
vector<lower=0>[N] y; // observed times
vector<lower=0,upper=1>[N] event; // censoring indicator (1=observed, 0=censored)
int M; // number of covariates
matrix[N, M] x; // matrix of covariates (with n rows and H columns)
}
parameters {
vector[M] beta; // Coefficients in the linear predictor (including intercept)
real<lower=0> alpha; // shape parameter
}
transformed parameters {
vector[N] linpred;
vector[N] mu;
linpred = x*beta;
mu = exp(linpred);
}
model {
alpha ~ gamma(0.01,0.01);
beta ~ normal(0,5);
y ~ surv_gamma(event, alpha, mu);
}
'''
In [16]:
# Fit model 3 (log_mix likelihood) with the same data, formula and sampler
# settings as the earlier models.
testfit3 = survivalstan.fit_stan_survival_model(
    model_cohort='model 3',
    model_code=model_code3,
    df=d,
    time_col='t',
    event_col='event',
    formula='~ age_centered + sex',
    iter=5000,
    chains=4,
    seed=9001,
    FIT_FUN=stancache.cached_stan_fit,
    drop_intercept=False,
)
In [17]:
# Run time noted for the model 3 fit in the previous cell: 0:00:20.284146 elapsed
In [18]:
# Posterior summary for model 3, limited to lp__, alpha and beta.
survivalstan.utils.print_stan_summary(
    [testfit3],
    pars=['lp__', 'alpha', 'beta'],
)
In [40]:
# Stan program for model 4: vectorised gamma survival likelihood.
# surv_gamma_lpdf splits the rows into observed (d == 1) and censored index
# arrays, then makes one vectorised gamma_lpdf call over the observed rows
# and one vectorised gamma_lccdf call over the censored rows.
# The group sizes (num_obs, num_cens) are computed once in transformed data
# and passed in, because the index arrays need sizes known at declaration.
# mu = exp(x * beta); priors: alpha ~ gamma(0.01, 0.01), beta ~ normal(0, 5).
# Changes from the previous version: the leftover debug statement that
# dumped idx_obs on every log-density evaluation is removed, and the Stan
# comments on surv_gamma_lpdf and `x` are corrected.
model_code4 = '''
functions {
  // Counts the elements of a that are equal to val
  int count_value(vector a, real val) {
    int s;
    s = 0;
    for (i in 1:num_elements(a))
      if (a[i] == val)
        s = s + 1;
    return s;
  }
  // Defines the log-likelihood for right-censored gamma survival times:
  // log density for observed events plus log survival for censored rows
  real surv_gamma_lpdf (vector t, vector d, real shape, vector rate, int num_cens, int num_obs) {
    vector[2] log_lik;
    int idx_obs[num_obs];
    int idx_cens[num_cens];
    real prob;
    int i_cens;
    int i_obs;
    i_cens = 1;
    i_obs = 1;
    // Partition row indices into observed and censored groups
    for (i in 1:num_elements(t)) {
      if (d[i] == 1) {
        idx_obs[i_obs] = i;
        i_obs = i_obs+1;
      }
      else {
        idx_cens[i_cens] = i;
        i_cens = i_cens+1;
      }
    }
    log_lik[1] = gamma_lpdf(t[idx_obs] | shape, rate[idx_obs]);
    log_lik[2] = gamma_lccdf(t[idx_cens] | shape, rate[idx_cens]);
    prob = sum(log_lik);
    return prob;
  }
}
data {
  int N; // number of observations
  vector<lower=0>[N] y; // observed times
  vector<lower=0,upper=1>[N] event; // censoring indicator (1=observed, 0=censored)
  int M; // number of covariates
  matrix[N, M] x; // matrix of covariates (with N rows and M columns)
}
transformed data {
  int num_cens;
  int num_obs;
  num_obs = count_value(event, 1);
  num_cens = N - num_obs;
}
parameters {
  vector[M] beta; // Coefficients in the linear predictor (including intercept)
  real<lower=0> alpha; // shape parameter
}
transformed parameters {
  vector[N] linpred;
  vector[N] mu;
  linpred = x*beta;
  mu = exp(linpred);
}
model {
  alpha ~ gamma(0.01,0.01);
  beta ~ normal(0,5);
  y ~ surv_gamma(event, alpha, mu, num_cens, num_obs);
}
'''
In [41]:
# Fit model 4 (vectorised observed/censored likelihood) with the same data,
# formula and sampler settings as the earlier models.
testfit4 = survivalstan.fit_stan_survival_model(
    model_cohort='model 4',
    model_code=model_code4,
    df=d,
    time_col='t',
    event_col='event',
    formula='~ age_centered + sex',
    iter=5000,
    chains=4,
    seed=9001,
    FIT_FUN=stancache.cached_stan_fit,
    drop_intercept=False,
)
In [42]:
# Run time noted for the model 4 fit in the previous cell: 0:00:06.245552 elapsed
In [43]:
# Posterior summary for model 4, limited to lp__, alpha and beta.
survivalstan.utils.print_stan_summary(
    [testfit4],
    pars=['lp__', 'alpha', 'beta'],
)
In [44]:
# Plot the coefficient estimates from all four fits in one call.
survivalstan.utils.plot_coefs(
    [
        testfit,
        testfit2,
        testfit3,
        testfit4,
    ]
)
In [ ]: